# -*- coding: utf-8 -*-
"""
Created on Tue May 16 22:38:27 2023

@author: chendu
"""
import os

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from models.CompNet import Net

# Preprocessing pipeline: PIL image -> tensor in [0, 1], then normalize each
# RGB channel with mean 0.5 / std 0.5, mapping pixel values to roughly [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

# CIFAR-100 train/test splits. NOTE(review): download=False assumes the data
# already exists under ./data -- this raises if it is missing; set
# download=True for a first run.
trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                         download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                        download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Human-readable class names (index -> name), provided by the dataset.
classes = trainset.classes

epochs = 10
learning_rate = 0.001
# net = SimpleResNets.CA_ResNet10()
# net = denseNet.denseNet(10)
# Net(100) presumably returns a (model, model_name) tuple for 100 output
# classes -- TODO confirm against models/CompNet.py.
net, save_name = Net(100)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)


def train(epoch, log_interval=1000):
    """Run one training epoch over ``trainloader``.

    Prints the mean loss of the last ``log_interval`` mini-batches every
    ``log_interval`` batches, tagged with *epoch* and the batch index.
    """
    # Put the model into training mode (affects dropout / batch-norm layers).
    net.train()

    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # Reset accumulated gradients from the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backpropagation, then one SGD step
        # (w <- w - lr * dL/dw, with momentum).
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate the scalar loss and report periodically.
        running_loss += loss.item()
        if (batch_idx + 1) % log_interval == 0:
            avg_loss = running_loss / log_interval
            print(f'[{epoch}, {batch_idx + 1:5d}] loss: {avg_loss:.3f}')
            running_loss = 0.0


def validate(loss_vector, accuracy_vector):
    """Evaluate the network on ``testloader``.

    Appends the mean per-batch loss to *loss_vector* and the accuracy (as a
    float percentage) to *accuracy_vector*, and prints both.
    """
    # Evaluation mode: disables dropout / uses running batch-norm statistics.
    net.eval()

    total = 0
    val_loss, correct = 0.0, 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            # .item() directly; the .data attribute access is deprecated.
            val_loss += criterion(outputs, labels).item()
            # The class with the highest score is the prediction.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # BUG FIX: was `100 * correct // total` -- floor division truncated the
    # accuracy to a whole percent. Use true division for full precision.
    accuracy = 100.0 * correct / total
    accuracy_vector.append(accuracy)
    val_loss /= len(testloader)
    loss_vector.append(val_loss)
    # Report the measured number of test images rather than a hard-coded 10000.
    print(f'Accuracy of the network on the {total} test images: {accuracy:.2f} %')
    print(f'Average loss: {val_loss:.4f}')


def main():
    """Train the model, track validation curves, save weights, and plot.

    Side effects: writes ./cifar100_<save_name>.pth and two plot files each
    under imgs/ (.png and .eps), and shows the final figure.
    """
    print(f"net: {save_name}")
    lossv, accv = [], []
    for epoch in range(1, epochs + 1):
        train(epoch)
        validate(lossv, accv)

    # Persist the trained weights (state dict only, not the full module).
    path = './cifar100_{}.pth'.format(save_name)
    torch.save(net.state_dict(), path)

    # Count correct / total predictions per class name.
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    # Evaluation only -- no gradients needed.
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            _, predictions = torch.max(outputs, 1)
            for label, prediction in zip(labels, predictions):
                if label == prediction:
                    correct_pred[classes[label]] += 1
                total_pred[classes[label]] += 1

    # Print accuracy for each class, guarding against a class that never
    # appears in the test split (avoids ZeroDivisionError).
    for classname, correct_count in correct_pred.items():
        seen = total_pred[classname]
        accuracy = 100 * float(correct_count) / seen if seen else 0.0
        print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

    # Create the output directory once up front (was checked twice before);
    # exist_ok avoids the race between the exists() check and mkdir().
    os.makedirs('imgs', exist_ok=True)

    plt.figure(figsize=(5, 3))
    plt.plot(np.arange(1, epochs + 1), lossv)
    plt.title('validation loss')
    plt.savefig('imgs/lossv_{}.png'.format(save_name))
    plt.savefig('imgs/lossv_{}.eps'.format(save_name))

    plt.figure(figsize=(5, 3))
    plt.plot(np.arange(1, epochs + 1), accv)
    plt.title('validation accuracy')
    plt.savefig('imgs/accv_{}.png'.format(save_name))
    plt.savefig('imgs/accv_{}.eps'.format(save_name))
    plt.show()

    print(f"net: {save_name}")


if __name__ == '__main__':
    main()
